/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include "gcgraph.hpp"
#include <limits>

using namespace cv;

/*
This is implementation of image segmentation algorithm GrabCut described in
"GrabCut — Interactive Foreground Extraction using Iterated Graph Cuts".
Carsten Rother, Vladimir Kolmogorov, Andrew Blake.
 */

/*
 GMM - Gaussian Mixture Model
*/
class GMM {
public:
	static const int componentsCount = 5;

	GMM( Mat& _model );
	float operator()( Vec3f color ) const;
	float operator()( int ci, Vec3f color ) const;
	int whichComponent( Vec3f color ) const;

	void initLearning();
	void addSample( int ci, Vec3f color );
	void endLearning();
private:
	void calcInverseCovAndDeterm( int ci );
	Mat model;
	float* coefs;
	float* mean;
	float* cov;

	float inverseCovs[componentsCount][3][3];
	float covDeterms[componentsCount];

	float sums[componentsCount][3];
	float prods[componentsCount][3][3];
	int sampleCounts[componentsCount];
	int totalSampleCount;
};

GMM::GMM( Mat& _model ) {
	const int modelSize = 3/*mean*/ + 9/*covariance*/ + 1/*component weight*/;
	if ( _model.empty() ) {
		_model.create( 1, modelSize * componentsCount, CV_32FC1 );
		_model.setTo(Scalar(0));
	} else if ( (_model.type() != CV_32FC1) || (_model.rows != 1) || (_model.cols != modelSize * componentsCount) ) {
		CV_Error( CV_StsBadArg, "_model must have CV_32FC1 type, rows == 1 and cols == 13*componentsCount" );
	}

	model = _model;

	coefs = model.ptr<float>(0);
	mean = coefs + componentsCount;
	cov = mean + 3 * componentsCount;

	for ( int ci = 0; ci < componentsCount; ci++ )
		if ( coefs[ci] > 0 ) {
			calcInverseCovAndDeterm( ci );
		}
}

float GMM::operator()( Vec3f color ) const {
	float res = 0;
	for ( int ci = 0; ci < componentsCount; ci++ ) {
		res += coefs[ci] * (*this)(ci, color );
	}
	return res;
}

float GMM::operator()( int ci, Vec3f color ) const {
	float res = 0;
	if ( coefs[ci] > 0 ) {
		if ( covDeterms[ci] > std::numeric_limits<float>::epsilon() ) {
			Vec3f diff = color;
			float* m = mean + 3 * ci;
			diff[0] -= m[0]; diff[1] -= m[1]; diff[2] -= m[2];
			float mult = diff[0] * (diff[0] * inverseCovs[ci][0][0] + diff[1] * inverseCovs[ci][1][0] + diff[2] * inverseCovs[ci][2][0])
						 + diff[1] * (diff[0] * inverseCovs[ci][0][1] + diff[1] * inverseCovs[ci][1][1] + diff[2] * inverseCovs[ci][2][1])
						 + diff[2] * (diff[0] * inverseCovs[ci][0][2] + diff[1] * inverseCovs[ci][1][2] + diff[2] * inverseCovs[ci][2][2]);
			res = 1.0f / sqrt(covDeterms[ci]) * exp(-0.5f * mult);
		}
	}
	return res;
}

int GMM::whichComponent( Vec3f color ) const {
	int k = 0;
	float max = 0;

	for ( int ci = 0; ci < componentsCount; ci++ ) {
		float p = (*this)( ci, color );
		if ( p > max ) {
			k = ci;
			max = p;
		}
	}
	return k;
}

void GMM::initLearning() {
	for ( int ci = 0; ci < componentsCount; ci++) {
		sums[ci][0] = sums[ci][1] = sums[ci][2] = 0;
		prods[ci][0][0] = prods[ci][0][1] = prods[ci][0][2] = 0;
		prods[ci][1][0] = prods[ci][1][1] = prods[ci][1][2] = 0;
		prods[ci][2][0] = prods[ci][2][1] = prods[ci][2][2] = 0;
		sampleCounts[ci] = 0;
	}
	totalSampleCount = 0;
}

void GMM::addSample( int ci, Vec3f color ) {
	sums[ci][0] += color[0]; sums[ci][1] += color[1]; sums[ci][2] += color[2];
	prods[ci][0][0] += color[0] * color[0]; prods[ci][0][1] += color[0] * color[1]; prods[ci][0][2] += color[0] * color[2];
	prods[ci][1][0] += color[1] * color[0]; prods[ci][1][1] += color[1] * color[1]; prods[ci][1][2] += color[1] * color[2];
	prods[ci][2][0] += color[2] * color[0]; prods[ci][2][1] += color[2] * color[1]; prods[ci][2][2] += color[2] * color[2];
	sampleCounts[ci]++;
	totalSampleCount++;
}

void GMM::endLearning() {
	for ( int ci = 0; ci < componentsCount; ci++ ) {
		int n = sampleCounts[ci];
		if ( n == 0 ) {
			coefs[ci] = 0;
		} else {
			coefs[ci] = (float)n / totalSampleCount;

			float* m = mean + 3 * ci;
			m[0] = sums[ci][0] / n; m[1] = sums[ci][1] / n; m[2] = sums[ci][2] / n;

			float* c = cov + 9 * ci;
			c[0] = prods[ci][0][0] / n - m[0] * m[0]; c[1] = prods[ci][0][1] / n - m[0] * m[1]; c[2] = prods[ci][0][2] / n - m[0] * m[2];
			c[3] = prods[ci][1][0] / n - m[1] * m[0]; c[4] = prods[ci][1][1] / n - m[1] * m[1]; c[5] = prods[ci][1][2] / n - m[1] * m[2];
			c[6] = prods[ci][2][0] / n - m[2] * m[0]; c[7] = prods[ci][2][1] / n - m[2] * m[1]; c[8] = prods[ci][2][2] / n - m[2] * m[2];

			calcInverseCovAndDeterm(ci);
		}
	}
}

void GMM::calcInverseCovAndDeterm( int ci ) {
	if ( coefs[ci] > 0 ) {
		float* c = cov + 9 * ci;
		float dtrm =
			covDeterms[ci] = c[0] * (c[4] * c[8] - c[5] * c[7]) - c[1] * (c[3] * c[8] - c[5] * c[6]) + c[2] * (c[3] * c[7] - c[4] * c[6]);

		if ( dtrm > std::numeric_limits<float>::epsilon() ) {
			inverseCovs[ci][0][0] =  (c[4] * c[8] - c[5] * c[7]) / dtrm;
			inverseCovs[ci][1][0] = -(c[3] * c[8] - c[5] * c[6]) / dtrm;
			inverseCovs[ci][2][0] =  (c[3] * c[7] - c[4] * c[6]) / dtrm;
			inverseCovs[ci][0][1] = -(c[1] * c[8] - c[2] * c[7]) / dtrm;
			inverseCovs[ci][1][1] =  (c[0] * c[8] - c[2] * c[6]) / dtrm;
			inverseCovs[ci][2][1] = -(c[0] * c[7] - c[1] * c[6]) / dtrm;
			inverseCovs[ci][0][2] =  (c[1] * c[5] - c[2] * c[4]) / dtrm;
			inverseCovs[ci][1][2] = -(c[0] * c[5] - c[2] * c[3]) / dtrm;
			inverseCovs[ci][2][2] =  (c[0] * c[4] - c[1] * c[3]) / dtrm;
		}
	}
}

/*
  Calculate beta - parameter of GrabCut algorithm.
  beta = 1/(2*avg(sqr(||color[i] - color[j]||)))
*/
float calcBeta( const Mat& img ) {
	float beta = 0;
	for ( int y = 0; y < img.rows; y++ ) {
		for ( int x = 0; x < img.cols; x++ ) {
			Vec3f color = img.at<Vec3b>(y, x);
			if ( x > 0 ) { // left
				Vec3f diff = color - (Vec3f)img.at<Vec3b>(y, x - 1);
				beta += diff.dot(diff);
			}
			if ( y > 0 && x > 0 ) { // upleft
				Vec3f diff = color - (Vec3f)img.at<Vec3b>(y - 1, x - 1);
				beta += diff.dot(diff);
			}
			if ( y > 0 ) { // up
				Vec3f diff = color - (Vec3f)img.at<Vec3b>(y - 1, x);
				beta += diff.dot(diff);
			}
			if ( y > 0 && x < img.cols - 1) { // upright
				Vec3f diff = color - (Vec3f)img.at<Vec3b>(y - 1, x + 1);
				beta += diff.dot(diff);
			}
		}
	}
	beta = 1.f / (2 * beta / (4 * img.cols * img.rows - 3 * img.cols - 3 * img.rows + 2) );
	return beta;
}

/*
  Calculate weights of noterminal vertices of graph.
  beta and gamma - parameters of GrabCut algorithm.
 */
void calcNWeights( const Mat& img, Mat& leftW, Mat& upleftW, Mat& upW, Mat& uprightW, float beta, float gamma ) {
	const float gammaDivSqrt2 = gamma / std::sqrt(2.0f);
	leftW.create( img.rows, img.cols, CV_32FC1 );
	upleftW.create( img.rows, img.cols, CV_32FC1 );
	upW.create( img.rows, img.cols, CV_32FC1 );
	uprightW.create( img.rows, img.cols, CV_32FC1 );
	for ( int y = 0; y < img.rows; y++ ) {
		for ( int x = 0; x < img.cols; x++ ) {
			Vec3f color = img.at<Vec3b>(y, x);
			if ( x - 1 >= 0 ) { // left
				Vec3f diff = color - (Vec3f)img.at<Vec3b>(y, x - 1);
				leftW.at<float>(y, x) = gamma * exp(-beta * diff.dot(diff));
			} else {
				leftW.at<float>(y, x) = 0;
			}
			if ( x - 1 >= 0 && y - 1 >= 0 ) { // upleft
				Vec3f diff = color - (Vec3f)img.at<Vec3b>(y - 1, x - 1);
				upleftW.at<float>(y, x) = gammaDivSqrt2 * exp(-beta * diff.dot(diff));
			} else {
				upleftW.at<float>(y, x) = 0;
			}
			if ( y - 1 >= 0 ) { // up
				Vec3f diff = color - (Vec3f)img.at<Vec3b>(y - 1, x);
				upW.at<float>(y, x) = gamma * exp(-beta * diff.dot(diff));
			} else {
				upW.at<float>(y, x) = 0;
			}
			if ( x + 1 < img.cols - 1 && y - 1 >= 0 ) { // upright
				Vec3f diff = color - (Vec3f)img.at<Vec3b>(y - 1, x + 1);
				uprightW.at<float>(y, x) = gammaDivSqrt2 * exp(-beta * diff.dot(diff));
			} else {
				uprightW.at<float>(y, x) = 0;
			}
		}
	}
}

/*
  Check size, type and element values of mask matrix.
 */
void checkMask( const Mat& img, const Mat& mask ) {
	if ( mask.empty() ) {
		CV_Error( CV_StsBadArg, "mask is empty" );
	}
	if ( mask.type() != CV_8UC1 ) {
		CV_Error( CV_StsBadArg, "mask must have CV_8UC1 type" );
	}
	if ( mask.cols != img.cols || mask.rows != img.rows ) {
		CV_Error( CV_StsBadArg, "mask must have as many rows and cols as img" );
	}
	for ( int y = 0; y < mask.rows; y++ ) {
		for ( int x = 0; x < mask.cols; x++ ) {
			uchar val = mask.at<uchar>(y, x);
			if ( val != GC_BGD && val != GC_FGD && val != GC_PR_BGD && val != GC_PR_FGD )
				CV_Error( CV_StsBadArg, "mask element value must be equel"
						  "GC_BGD or GC_FGD or GC_PR_BGD or GC_PR_FGD" );
		}
	}
}

/*
  Initialize mask using rectangular.
*/
void initMaskWithRect( Mat& mask, Size imgSize, Rect rect ) {
	mask.create( imgSize, CV_8UC1 );
	mask.setTo( GC_BGD );

	rect.x = max(0, rect.x);
	rect.y = max(0, rect.y);
	rect.width = min(rect.width, imgSize.width - rect.x);
	rect.height = min(rect.height, imgSize.height - rect.y);

	(mask(rect)).setTo( Scalar(GC_PR_FGD) );
}

/*
  Initialize GMM background and foreground models using kmeans algorithm.
*/
void initGMMs( const Mat& img, const Mat& mask, GMM& bgdGMM, GMM& fgdGMM ) {
	const int kMeansItCount = 10;
	const int kMeansType = KMEANS_PP_CENTERS;

	Mat bgdLabels, fgdLabels;
	vector<Vec3f> bgdSamples, fgdSamples;
	Point p;
	for ( p.y = 0; p.y < img.rows; p.y++ ) {
		for ( p.x = 0; p.x < img.cols; p.x++ ) {
			if ( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD ) {
				bgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );
			} else { // GC_FGD | GC_PR_FGD
				fgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );
			}
		}
	}
	CV_Assert( !bgdSamples.empty() && !fgdSamples.empty() );
	Mat _bgdSamples( (int)bgdSamples.size(), 3, CV_32FC1, &bgdSamples[0][0] );
	kmeans( _bgdSamples, GMM::componentsCount, bgdLabels,
			TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType, 0 );
	Mat _fgdSamples( (int)fgdSamples.size(), 3, CV_32FC1, &fgdSamples[0][0] );
	kmeans( _fgdSamples, GMM::componentsCount, fgdLabels,
			TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType, 0 );

	bgdGMM.initLearning();
	for ( int i = 0; i < (int)bgdSamples.size(); i++ ) {
		bgdGMM.addSample( bgdLabels.at<int>(i, 0), bgdSamples[i] );
	}
	bgdGMM.endLearning();

	fgdGMM.initLearning();
	for ( int i = 0; i < (int)fgdSamples.size(); i++ ) {
		fgdGMM.addSample( fgdLabels.at<int>(i, 0), fgdSamples[i] );
	}
	fgdGMM.endLearning();
}

/*
  Assign GMMs components for each pixel.
*/
void assignGMMsComponents( const Mat& img, const Mat& mask, const GMM& bgdGMM, const GMM& fgdGMM, Mat& compIdxs ) {
	Point p;
	for ( p.y = 0; p.y < img.rows; p.y++ ) {
		for ( p.x = 0; p.x < img.cols; p.x++ ) {
			Vec3f color = img.at<Vec3b>(p);
			compIdxs.at<int>(p) = mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD ?
								  bgdGMM.whichComponent(color) : fgdGMM.whichComponent(color);
		}
	}
}

/*
  Learn GMMs parameters.
*/
void learnGMMs( const Mat& img, const Mat& mask, const Mat& compIdxs, GMM& bgdGMM, GMM& fgdGMM ) {
	bgdGMM.initLearning();
	fgdGMM.initLearning();
	Point p;
	for ( int ci = 0; ci < GMM::componentsCount; ci++ ) {
		for ( p.y = 0; p.y < img.rows; p.y++ ) {
			for ( p.x = 0; p.x < img.cols; p.x++ ) {
				if ( compIdxs.at<int>(p) == ci ) {
					if ( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD ) {
						bgdGMM.addSample( ci, img.at<Vec3b>(p) );
					} else {
						fgdGMM.addSample( ci, img.at<Vec3b>(p) );
					}
				}
			}
		}
	}
	bgdGMM.endLearning();
	fgdGMM.endLearning();
}

/*
  Construct GCGraph
*/
void constructGCGraph( const Mat& img, const Mat& mask, const GMM& bgdGMM, const GMM& fgdGMM, float lambda,
					   const Mat& leftW, const Mat& upleftW, const Mat& upW, const Mat& uprightW,
					   GCGraph<float>& graph ) {
	int vtxCount = img.cols * img.rows,
		edgeCount = 2 * (4 * img.cols * img.rows - 3 * (img.cols + img.rows) + 2);
	graph.create(vtxCount, edgeCount);
	Point p;
	for ( p.y = 0; p.y < img.rows; p.y++ ) {
		for ( p.x = 0; p.x < img.cols; p.x++) {
			// add node
			int vtxIdx = graph.addVtx();
			Vec3b color = img.at<Vec3b>(p);

			// set t-weights
			float fromSource, toSink;
			if ( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD ) {
				fromSource = -log( bgdGMM(color) );
				toSink = -log( fgdGMM(color) );
			} else if ( mask.at<uchar>(p) == GC_BGD ) {
				fromSource = 0;
				toSink = lambda;
			} else { // GC_FGD
				fromSource = lambda;
				toSink = 0;
			}
			graph.addTermWeights( vtxIdx, fromSource, toSink );

			// set n-weights
			if ( p.x > 0 ) {
				float w = leftW.at<float>(p);
				graph.addEdges( vtxIdx, vtxIdx - 1, w, w );
			}
			if ( p.x > 0 && p.y > 0 ) {
				float w = upleftW.at<float>(p);
				graph.addEdges( vtxIdx, vtxIdx - img.cols - 1, w, w );
			}
			if ( p.y > 0 ) {
				float w = upW.at<float>(p);
				graph.addEdges( vtxIdx, vtxIdx - img.cols, w, w );
			}
			if ( p.x < img.cols - 1 && p.y > 0 ) {
				float w = uprightW.at<float>(p);
				graph.addEdges( vtxIdx, vtxIdx - img.cols + 1, w, w );
			}
		}
	}
}

/*
  Estimate segmentation using MaxFlow algorithm
*/
void estimateSegmentation( GCGraph<float>& graph, Mat& mask ) {
	graph.maxFlow();
	Point p;
	for ( p.y = 0; p.y < mask.rows; p.y++ ) {
		for ( p.x = 0; p.x < mask.cols; p.x++ ) {
			if ( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD ) {
				if ( graph.inSourceSegment( p.y * mask.cols + p.x /*vertex index*/ ) ) {
					mask.at<uchar>(p) = GC_PR_FGD;
				} else {
					mask.at<uchar>(p) = GC_PR_BGD;
				}
			}
		}
	}
}

void cv::grabCut( const Mat& img, Mat& mask, Rect rect,
				  Mat& bgdModel, Mat& fgdModel,
				  int iterCount, int mode ) {
	if ( img.empty() ) {
		CV_Error( CV_StsBadArg, "image is empty" );
	}
	if ( img.type() != CV_8UC3 ) {
		CV_Error( CV_StsBadArg, "image mush have CV_8UC3 type" );
	}

	GMM bgdGMM( bgdModel ), fgdGMM( fgdModel );
	Mat compIdxs( img.size(), CV_32SC1 );

	if ( mode == GC_INIT_WITH_RECT || mode == GC_INIT_WITH_MASK ) {
		if ( mode == GC_INIT_WITH_RECT ) {
			initMaskWithRect( mask, img.size(), rect );
		} else { // flag == GC_INIT_WITH_MASK
			checkMask( img, mask );
		}
		initGMMs( img, mask, bgdGMM, fgdGMM );
	}

	if ( iterCount <= 0) {
		return;
	}

	if ( mode == GC_EVAL ) {
		checkMask( img, mask );
	}

	const float gamma = 50;
	const float lambda = 9 * gamma;
	const float beta = calcBeta( img );

	Mat leftW, upleftW, upW, uprightW;
	calcNWeights( img, leftW, upleftW, upW, uprightW, beta, gamma );

	for ( int i = 0; i < iterCount; i++ ) {
		GCGraph<float> graph;
		assignGMMsComponents( img, mask, bgdGMM, fgdGMM, compIdxs );
		learnGMMs( img, mask, compIdxs, bgdGMM, fgdGMM );
		constructGCGraph(img, mask, bgdGMM, fgdGMM, lambda, leftW, upleftW, upW, uprightW, graph );
		estimateSegmentation( graph, mask );
	}
}
